% Start from a clean console and workspace, then bring the contents of the
% Data MAT-file into the workspace.
clc;
clearvars;
load('Data');
format short g


%=========== Simulation parameters ===========%
% max_iter sizes the result containers and the parfor range further down;
% market_size is the number of household-type draws per subsample;
% cluster switches between the two solver configurations below.
max_iter    = 100;
market_size = 100;
cluster     = 1;

% Text versions of the parameters, used when the output file name is built.
text_cluster = num2str(cluster);
text_size    = num2str(market_size);
text_iter    = num2str(max_iter);


%=========== Per-iteration result containers ===========%
% One cell per iteration; each is filled inside the parfor loop.
share            = cell(max_iter,1);
DivorceCosts     = cell(max_iter,1);
HHIndices        = cell(max_iter,1);
StabilityIndices = cell(max_iter,1);


%=========== Solver name, worker limit and toolbox paths ===========%
% cluster == 0 -> gurobi with at most 2 workers; any other value -> mosek
% with at most 20 workers plus the YALMIP / MOSEK folders on the path.
if cluster ~= 0
    text_solver = 'mosek';
    parnum = 20;
    addpath(genpath("/opt/apps/easybuild/software/YALMIP/R20230609/"))
    addpath(genpath("~/mosek/10.1/toolbox/r2017a"))
else
    text_solver = 'gurobi';
    parnum = 2;
end


%=========== Chores = housework + childcare hours ===========%
% Male chore hours are zeroed for single-female households and female chore
% hours for single-male households.
Data.chorehm = Data.hworkm + Data.chcarem;
Data.chorehm(Data.singlefemale == 1) = 0;
Data.chorehf = Data.hworkf + Data.chcaref;
Data.chorehf(Data.singlemale == 1) = 0;


%=========== Poverty line ===========%
% Household consumption is q + Q. Rows with couple == 1 get half of it per
% head, rows with couple == 0 keep the full amount. The poverty line is
% 60% of the median per-capita value.
Data.consumption = Data.q + Data.Q;
Data.percapitaconsumption(Data.couple == 1) = Data.consumption(Data.couple == 1)./2;
Data.percapitaconsumption(Data.couple == 0) = Data.consumption(Data.couple == 0);
povertyline = 0.6*median(Data.percapitaconsumption);


%=========== Age restrictions for the marriage market ===========%
% lb / ub are the 2.5th and 97.5th percentiles of the male-minus-female age
% gap among observed couples; they are passed to get_consideration_sets.
age_diff = Data.agem(Data.couple==1) - Data.agef(Data.couple==1);
lb = prctile(age_diff,2.5);
ub = prctile(age_diff,97.5);


%===========  Define age, education and children categories for preference types ===========%
% Vectorised over all rows of Data (no per-row writes into the table).
% Every category column is initialised to 0 first, so rows that match none
% of the brackets below (NaN, age <= 1, negative child count) are
% explicitly 0 instead of depending on implicit padding of a new column.

% Age brackets: 0 = missing / unclassified, 1 = (1,35], 2 = (35,50], 3 = above 50.
% Comparisons with NaN are false, so NaN ages stay at 0.
Data.agecatm = zeros(size(Data,1),1);
Data.agecatm(Data.agem > 1  & Data.agem <= 35) = 1;
Data.agecatm(Data.agem > 35 & Data.agem <= 50) = 2;
Data.agecatm(Data.agem > 50) = 3;

Data.agecatf = zeros(size(Data,1),1);
Data.agecatf(Data.agef > 1  & Data.agef <= 35) = 1;
Data.agecatf(Data.agef > 35 & Data.agef <= 50) = 2;
Data.agecatf(Data.agef > 50) = 3;

% Education category: the recorded education code, with NaN mapped to 0.
Data.educatm = Data.edum;
Data.educatm(isnan(Data.edum)) = 0;

Data.educatf = Data.eduf;
Data.educatf(isnan(Data.eduf)) = 0;

% Children category: 0 = missing, 1 = no children, 2 = at least one child.
Data.kidscat = zeros(size(Data,1),1);
Data.kidscat(Data.nchild == 0) = 1;
Data.kidscat(Data.nchild > 0)  = 2;

% Individual type code = education*100 + children*10 + age bracket.
Data.typem = Data.educatm*100 + Data.kidscat*10 + Data.agecatm;
Data.typef = Data.educatf*100 + Data.kidscat*10 + Data.agecatf;

% Zero out the type of the absent spouse in single households. The
% single-male flag is tested first; the male type is only cleared for a
% single-female row that is not also flagged single-male.
Data.typef(Data.singlemale == 1) = 0;
Data.typem(Data.singlemale ~= 1 & Data.singlefemale == 1) = 0;


%=========== Household types and sampling weights ===========%
% A household type packs the male and the female type code into one number.
Data.types = 1000.*Data.typem + Data.typef;

% value_counts columns: [type code, number of households, share of sample].
[C,~,ic] = unique(Data.types);
a_counts = accumarray(ic,1);
value_counts = [C, a_counts, a_counts./sum(a_counts)];


%=========== Define couple type based on education and employment status of spouses ===========%
% Vectorised over all rows of Data. The three columns start at 0, so every
% non-couple row (single male, single female, or unflagged) is explicitly 0
% rather than relying on implicit padding of a newly created column.
%   hcatm / hcatf : 10*(1 + male_emp / female_emp) + education category
%   hcat          : hcatm*100 + hcatf
Data.hcatm = zeros(size(Data,1),1);
Data.hcatf = zeros(size(Data,1),1);
Data.hcat  = zeros(size(Data,1),1);

Data.hcatm(Data.couple == 1) = 10*(1 + Data.male_emp(Data.couple == 1)) + Data.educatm(Data.couple == 1);
Data.hcatf(Data.couple == 1) = 10*(1 + Data.female_emp(Data.couple == 1)) + Data.educatf(Data.couple == 1);
Data.hcat(Data.couple == 1)  = Data.hcatm(Data.couple == 1)*100 + Data.hcatf(Data.couple == 1);



%%
%=========== Estimate the model with subsamples ===========%
% A single Threefry stream is handed to the workers through a pool
% Constant; every iteration switches that stream to substream number iter
% before drawing.
sc = parallel.pool.Constant(RandStream('Threefry'));


parfor (iter = 1:max_iter,parnum)

    stream = sc.Value;
    stream.Substream = iter;

    % Keep redrawing a subsample until get_indices returns solved == 0 and
    % get_ceb returns problem == 0 with some non-NaN bounds.
    % NOTE(review): the number of redraws is not capped.
    complete = 0;
    while complete == 0

        %=========== randomly draw household types from their weighted distribution ===========%
        rand_index = randsample(stream,value_counts(:,1), market_size, true, value_counts(:,3));
        [C,~,ic] = unique(rand_index);
        a_counts = accumarray(ic,1);
        index_counts = [C, a_counts];   % [type code, number of draws of that type]

        %=========== Given the number of household types required, randomly draw households of that type from the observed households ===========%
        Data_r = [];
        for i = 1:size(index_counts,1)
            Extract = Data(Data.types == index_counts(i,1),:);
            Extract_index = randsample(stream,size(Extract,1),index_counts(i,2),true);
            Data_r = [Data_r; Extract(Extract_index,:)];
        end

        %=========== Define consideration set ===========%
        consideration_set = get_consideration_sets(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,Data_r.agem,Data_r.agef,lb,ub);

        %=========== Compute stability indices ===========%
        fprintf('Computing Divorce Costs...\n');
        [temp_DivorceCostsM,temp_DivorceCostsF,temp_DivorceCostsMF,solved] = get_indices(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,...
            Data_r.leisurem,Data_r.leisuref,Data_r.chorehm,Data_r.chorehf,Data_r.wagem,Data_r.wagef,Data_r.q,Data_r.Q,...
            Data_r.nonlabor,consideration_set,text_solver);

        %=========== Compute CEBs ===========%
        % presumably solved == 0 is the solver's success code - confirm in get_indices
        if solved == 0
            fprintf('Computing CEBs...\n');
            [temp_min_male,temp_max_male,temp_min_female,temp_max_female, problem] = get_ceb(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,...
                Data_r.leisurem,Data_r.leisuref,Data_r.chorehm,Data_r.chorehf,Data_r.wagem,Data_r.wagef,Data_r.q,Data_r.Q,...
                Data_r.nonlabor,consideration_set,temp_DivorceCostsM+eps,temp_DivorceCostsF+eps,temp_DivorceCostsMF+eps,text_solver);

            % Accept the draw only if each bound vector has at least one
            % non-NaN entry and get_ceb reported problem == 0.
            if any(~isnan(temp_min_male(:))) && any(~isnan(temp_max_male(:))) && any(~isnan(temp_min_female(:))) && any(~isnan(temp_max_female(:))) && problem == 0

                %=========== Store stability indices ===========%
                % Rows of the subsample that are couples (nonzero flag), and
                % their household consumption used as the normaliser below.
                is_couple = Data_r.couple ~= 0;
                ncouple = nnz(is_couple);
                couple_consumption = Data_r.consumption(is_couple);

                % Pairs outside the consideration set do not count.
                temp_DivorceCostsMF(~consideration_set) = NaN;
                DC_temp = temp_DivorceCostsMF;

                % Row-wise (dimension 2) and column-wise (dimension 1) max and
                % NaN-ignoring average. The dimension is always given
                % explicitly so a one-row matrix is not collapsed along the
                % wrong axis.
                DivorceCosts_maxm = max(DC_temp,[],2);
                DivorceCosts_maxm = DivorceCosts_maxm(is_couple);
                DivorceCosts_maxf = max(DC_temp,[],1)';
                DivorceCosts_maxf = DivorceCosts_maxf(is_couple);
                DivorceCosts_avgf = (sum(DC_temp,1,'omitnan')')./(sum(~isnan(DC_temp),1)');
                DivorceCosts_avgf = DivorceCosts_avgf(is_couple);
                DivorceCosts_avgm = sum(DC_temp,2,'omitnan')./sum(~isnan(DC_temp),2);
                DivorceCosts_avgm = DivorceCosts_avgm(is_couple);

                % Columns 1-6: indices relative to household consumption,
                % 7: iteration number, 8-9: row / column counts of the
                % consideration set for each couple.
                DivorceCosts_temp = zeros(ncouple,9);
                DivorceCosts_temp(:,1) = temp_DivorceCostsM(is_couple)./couple_consumption;
                DivorceCosts_temp(:,2) = temp_DivorceCostsF(is_couple)./couple_consumption;
                DivorceCosts_temp(:,3) = DivorceCosts_maxm./couple_consumption;
                DivorceCosts_temp(:,4) = DivorceCosts_maxf./couple_consumption;
                DivorceCosts_temp(:,5) = DivorceCosts_avgm./couple_consumption;
                DivorceCosts_temp(:,6) = DivorceCosts_avgf./couple_consumption;
                DivorceCosts_temp(:,7) = iter;
                mmsize = sum(consideration_set,2,'omitnan');
                DivorceCosts_temp(:,8) = mmsize(is_couple);
                mmsize = sum(consideration_set,1,'omitnan')';
                DivorceCosts_temp(:,9) = mmsize(is_couple);

                % Stability flags: index at or below the 1e-5 tolerance.
                NBPStable = DivorceCosts_temp(:,5) <= 0.00001 & DivorceCosts_temp(:,6) <= 0.00001;
                IRMStable = DivorceCosts_temp(:,1) <= 0.00001;
                IRFStable = DivorceCosts_temp(:,2) <= 0.00001;
                % mean(...,1) keeps a 1x9 summary even when there is a single
                % couple; without the dimension a 1x9 row would average to a
                % scalar and break the later concatenation across iterations.
                DivorceCosts{iter} = [mean(DivorceCosts_temp,1), mean(NBPStable,1), mean(IRMStable,1), mean(IRFStable,1)];

                % Full index matrix: pairwise block, male column appended on
                % the right, female row (plus a 0 corner) appended below.
                temp_DivorceCostsM(~is_couple) = NaN;
                temp_DivorceCostsF(~is_couple) = NaN;
                DC_temp = [DC_temp, temp_DivorceCostsM];
                DC_temp = [DC_temp; [temp_DivorceCostsF',0]];
                StabilityIndices{iter} = DC_temp;
                HHIndices{iter} = Data_r.familyid;

                %=========== Store RICEBs ===========%
                % One row per sampled household; the column order must match
                % the VariableNames list used when the table is exported.
                qF_temp = zeros(length(Data_r.couple),19);
                qF_temp(:,1) = temp_min_male';
                qF_temp(:,2) = temp_max_male';
                qF_temp(:,3) = temp_min_female';
                qF_temp(:,4) = temp_max_female';
                qF_temp(:,5) = Data_r.employcatm;
                qF_temp(:,6) = Data_r.employcatf;
                qF_temp(:,7) = Data_r.educatm;
                qF_temp(:,8) = Data_r.educatf;
                qF_temp(:,9) = Data_r.leisurem;
                qF_temp(:,10) = Data_r.leisuref;
                qF_temp(:,11) = Data_r.chorehm;
                qF_temp(:,12) = Data_r.chorehf;
                qF_temp(:,13) = Data_r.q;
                qF_temp(:,14) = Data_r.Q;
                qF_temp(:,15) = Data_r.couple;
                qF_temp(:,16) = Data_r.singlemale;
                qF_temp(:,17) = Data_r.singlefemale;
                qF_temp(:,18) = Data_r.familyid;
                qF_temp(:,19) = iter;
                share{iter} = qF_temp;

                complete = 1;
                fprintf('*********** Iteration %d finished *************\n', iter);
            end

        end
    end
end

% Save the whole workspace under a name that encodes market size,
% iteration count and the cluster flag.
filename = ['DataPSID_results_cebs_individual_size',text_size, ...
            '_iter',text_iter, ...
            '_Cluster',text_cluster];
save(filename)

% Stack the per-iteration cells vertically into one matrix each.
DivorceCosts_combine = vertcat(DivorceCosts{:});
share_combine = vertcat(share{:});

% Export the stacked bounds as a table under the same base name.
% NOTE(review): columns 11-12 are labelled 'hworkm'/'hworkf' but are filled
% from chorehm/chorehf in the estimation loop; labels kept as they are
% because they define the output file's schema.
results = array2table(share_combine, 'VariableNames', ...
    {'sharem_min','sharem_max','sharef_min','sharef_max', ...
     'empym','empyf','edum','eduf', ...
     'leisurem','leisuref','hworkm','hworkf', ...
     'q','Q','couple','singlemale','singlefemale','familyid','iter'});
writetable(results,filename);



